[gpt-oss] triton kernel mxfp4 #22421
Conversation
Signed-off-by: <zyy1102000@gmail.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request introduces support for mxfp4 quantization on Hopper GPUs by integrating a new Triton kernel for MoE layers. The changes include adding the kernel wrappers, modifying the mxfp4 quantization path to use it, and adding corresponding tests. The implementation looks solid, but I have two high-level concerns. First, the number of warps for the Triton kernel is configured statically based on an environment variable, which might not be optimal or correct for dynamic batch sizes at runtime. Second, a utility function modifies a global configuration flag, which is a risky pattern that could lead to hard-to-debug side effects. Addressing these points would improve the robustness and maintainability of this new feature.
```python
# FIXME warp need to be adjusted based on batch size
# only apply to batched mode
if self.moe.use_ep:
    num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
    num_warps = 8
```
The FIXME comment on line 301 indicates that num_warps should be adjusted based on the batch size. The current implementation determines num_warps based on the static environment variable VLLM_MOE_DP_CHUNK_SIZE, which may not reflect the dynamic batch size at runtime. This static configuration could lead to suboptimal performance or potential correctness issues if the Triton kernel has strict requirements for num_warps based on the input size. This value is used during weight loading to swizzle the weights, so it cannot be changed dynamically per batch without re-swizzling. This suggests a potential design issue that should be addressed for robust performance and correctness.
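To illustrate the distinction the review draws: the PR fixes `num_warps` once, from a static environment variable, at weight-load time. A per-batch choice would look like the hypothetical helper below (the name and threshold are illustrative, not from the PR), but it cannot simply be applied at runtime here because the chosen warp count is baked into the weight swizzle, so changing it would require re-swizzling the weights.

```python
# Hypothetical helper sketching a per-batch num_warps choice.
# In the PR this cannot be done dynamically: num_warps is fixed at
# weight-load time because the swizzled weight layout depends on it.
def pick_num_warps(num_tokens: int, threshold: int = 512) -> int:
    """Return a warp count for the Triton kernel given the batch size."""
    return 4 if num_tokens <= threshold else 8
```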
```python
if current_platform.is_cuda() and \
        current_platform.is_device_capability(100):
    constraints = {
        "is_persistent": True,
        "epilogue_subtile": 1,
    }
    opt_flags.update_opt_flags_constraints(constraints)
```
The function _swizzle_mxfp4 modifies a global state by calling opt_flags.update_opt_flags_constraints(constraints). Modifying global state within a utility function is a dangerous pattern as it can introduce non-local side effects that are difficult to debug, especially in a system that might handle multiple models or requests concurrently. This could cause issues if different models or layers have conflicting requirements for these optimization flags. It would be safer to manage this global state with more care, for example, by using a context manager to set and restore the flags, or by passing constraints as parameters to the underlying kernel if the API supports it.
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
```python
def has_triton_kernels() -> bool:
    """Whether the optional `triton_kernels` package is available."""
    return _has_module("triton_kernels")
```
QQ: How can I install this?
We need to install directly from the triton repo:

```shell
uv pip install triton/python/triton_kernels --no-deps
```

There's no PyPI wheel yet.
Hmm, this broke the trunk.
Are you running …? You need to install it from the triton repo:

```shell
git clone https://github.com/triton-lang/triton
uv pip install triton/python/triton_kernels --no-deps
```
Pushed a fix: #22529
Just FYI, the error shows up on the llama4 benchmark run (https://github.com/pytorch/pytorch-integration-testing/actions/runs/16834994069/job/47692144587#step:14:3962), so it affects other models too.
Yeah, I was running DeepSeek. The code path is shared. Thanks for the quick fix!
Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Noam Gat <noamgat@gmail.com>
Hi @zyongye, can we use that kernel on Blackwell? If so, could you provide the Triton commit? I encountered the following issue when running the unit tests locally.
Hi @yiliu30, for Blackwell (SM100) we have kernels from FlashInfer available; please see the recipe for details: https://docs.vllm.ai/projects/recipes/en/latest/OpenAI/GPT-OSS.html#b200
Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Paul Pak <paulpak58@gmail.com>
Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Diego-Castan <diego.castan@ibm.com>
Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
```python
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
                              value_layout, **value_layout_opts)
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
                       **scale_layout_opts)
return quant_tensor, InFlexData(), scale
```
Is it safe to unwrap from triton_kernels.tensor.Tensor from here? Could we avoid it in the first place?
This is a util function from triton_kernels. It is designed to take a triton_kernels.Tensor instead of a torch.Tensor.
```python
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
```
Needs nightly torch and triton main to work.
Don't merge; waiting for the accuracy test.